{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code for the **\"AlexNet inversion\"** figure from the main paper and the **\"VGG inversion\"** figure from the supplementary material."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "#os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import numpy as np\n",
    "from models import *\n",
    "\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "from utils.feature_inversion_utils import *\n",
    "from utils.perceptual_loss.perceptual_loss import get_pretrained_net\n",
    "from utils.common_utils import *\n",
    "\n",
    "# Enable cuDNN and let it auto-tune convolution algorithms.\n",
    "torch.backends.cudnn.enabled = True\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "# Tensor type used for every network and tensor below: GPU when CUDA is\n",
    "# available, otherwise fall back to CPU so the notebook still runs.\n",
    "dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
    "\n",
    "PLOT = True  # show intermediate reconstructions while optimizing\n",
    "fname = './data/feature_inversion/building.jpg'  # image whose features are inverted\n",
    "\n",
    "pretrained_net = 'alexnet_caffe' # 'vgg19_caffe'\n",
    "layers_to_use = 'fc6' # comma-separated string of layer names e.g. 'fc6,fc7'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup pretrained net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained feature extractor and cast it to the working tensor type.\n",
    "cnn = get_pretrained_net(pretrained_net).type(dtype)\n",
    "\n",
    "# Options handed to the matcher: which layers to use and what to compare.\n",
    "opt_content = {'layers': layers_to_use, 'what': 'features'}\n",
    "\n",
    "# Truncate the network right after the deepest requested layer;\n",
    "# everything past it is not needed.\n",
    "keys = list(cnn._modules)\n",
    "requested = opt_content['layers'].split(',')\n",
    "max_idx = max(map(keys.index, requested))\n",
    "for k in keys[max_idx + 1:]:\n",
    "    del cnn._modules[k]\n",
    "\n",
    "print(cnn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side length of the target image fed to the pretrained net.\n",
    "# NOTE(review): the names configured in this notebook are 'alexnet_caffe' and\n",
    "# 'vgg19_caffe', so the 'alexnet' comparison below is False for both and 224 is\n",
    "# always used -- confirm whether 227 was intended for AlexNet.\n",
    "if pretrained_net == 'alexnet':\n",
    "    imsize = 227\n",
    "else:\n",
    "    imsize = 224\n",
    "\n",
    "# Spatial size passed to get_noise for the generator input\n",
    "# (something divisible by a power of two).\n",
    "imsize_net = 256\n",
    "\n",
    "# VGG and AlexNet need their input to be correctly normalized.\n",
    "preprocess = get_preprocessor(imsize)\n",
    "deprocess = get_deprocessor()\n",
    "\n",
    "img_content_pil, img_content_np = get_image(fname, imsize)\n",
    "\n",
    "# Add a leading batch dimension and cast to the working tensor type.\n",
    "img_content_prerocessed = preprocess(img_content_pil)[None, :].type(dtype)\n",
    "\n",
    "img_content_pil"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup matcher and net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Build the matcher for the layers listed in opt_content.\n",
    "matcher_content = get_matcher(cnn, opt_content)\n",
    "\n",
    "# 'store' mode: this forward pass on the target image presumably records its\n",
    "# features as the targets for later matching -- TODO confirm in get_matcher.\n",
    "matcher_content.mode = 'store'\n",
    "cnn(img_content_prerocessed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kind of generator input, passed to get_noise below.\n",
    "INPUT = 'noise'\n",
    "# Padding mode passed to the generator.\n",
    "pad = 'zero' # 'reflection'\n",
    "# What get_params should return for optimization.\n",
    "OPT_OVER = 'net' #'net,input'\n",
    "OPTIMIZER = 'adam' # 'LBFGS'\n",
    "LR = 0.001\n",
    "\n",
    "num_iter = 3100\n",
    "\n",
    "# Number of channels of the generator input.\n",
    "input_depth = 32\n",
    "# Fixed generator input, detached from the autograd graph.\n",
    "net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator network producing a 3-channel image from the input_depth-channel input.\n",
    "net = skip(input_depth, 3, num_channels_down = [16, 32, 64, 128, 128, 128],\n",
    "                           num_channels_up =   [16, 32, 64, 128, 128, 128],\n",
    "                           num_channels_skip = [4, 4, 4, 4, 4, 4],   \n",
    "                           filter_size_down = [7, 7, 5, 5, 3, 3], filter_size_up = [7, 7, 5, 5, 3, 3], \n",
    "                           upsample_mode='nearest', downsample_mode='avg',\n",
    "                           need_sigmoid=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n",
    "\n",
    "# Compute number of parameters (numel() gives each tensor's element count directly).\n",
    "s = sum(p.numel() for p in net.parameters())\n",
    "print ('Number of params: %d' % s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def closure():\n",
    "    \"\"\"Run one forward/backward pass of the inversion objective and return the loss.\n",
    "\n",
    "    Reads the notebook globals ``net``, ``net_input``, ``cnn``, ``matcher_content``,\n",
    "    ``imsize`` and ``PLOT``, and increments the global iteration counter ``i``.\n",
    "    \"\"\"\n",
    "    global i\n",
    "\n",
    "    # Generate an image and crop it to the size used for the pretrained net.\n",
    "    generated = net(net_input)[:, :, :imsize, :imsize]\n",
    "\n",
    "    # Forward pass through the pretrained net; the losses are then read off the\n",
    "    # matcher (presumably filled in as a side effect of this call).\n",
    "    cnn(vgg_preprocess_var(generated))\n",
    "    loss = sum(matcher_content.losses.values())\n",
    "    loss.backward()\n",
    "\n",
    "    print ('Iteration %05d    Loss %.3f' % (i, loss.item()), '\\r', end='')\n",
    "\n",
    "    # Show the current reconstruction every 200 iterations.\n",
    "    show_preview = PLOT and i % 200 == 0\n",
    "    if show_preview:\n",
    "        preview = np.clip(torch_to_np(generated), 0, 1)\n",
    "        plot_image_grid([preview], 3, 3)\n",
    "\n",
    "    i += 1\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the iteration counter used by closure().\n",
    "i=0\n",
    "# Switch the matcher from 'store' to 'match' mode before optimizing\n",
    "# (presumably it now computes losses against the stored features -- TODO confirm).\n",
    "matcher_content.mode = 'match'\n",
    "# Collect what to optimize, as selected by OPT_OVER.\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final reconstruction: generator output cropped to the target image size.\n",
    "out = net(net_input)[:, :, :imsize, :imsize]\n",
    "plot_image_grid([torch_to_np(out)], 3, 3);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code above was used to produce the images from the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Appendix: more noise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also found that adding heavy noise sometimes improves the results (see below). Interestingly, the network manages to adapt to very heavy noise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh generator input with only 2 channels for the heavy-noise experiment.\n",
    "input_depth = 2\n",
    "net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()\n",
    "\n",
    "# Fresh generator for the new input depth; need_bias=True is passed explicitly here.\n",
    "net = skip(input_depth, 3, num_channels_down = [16, 32, 64, 128, 128, 128],\n",
    "                           num_channels_up =   [16, 32, 64, 128, 128, 128],\n",
    "                           num_channels_skip = [4, 4, 4, 4, 4, 4],   \n",
    "                           filter_size_up = [7, 7, 5, 5, 3, 3], filter_size_down = [7, 7, 5, 5, 3, 3],\n",
    "                           upsample_mode='nearest', downsample_mode='avg',\n",
    "                           need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def closure():\n",
    "    \"\"\"Forward/backward pass with annealed weight and input noise; returns the loss.\n",
    "\n",
    "    Noise schedule, keyed on the global iteration counter ``i``:\n",
    "      * i < 10000: weight noise (std / 50) and input noise scaled by 10\n",
    "      * i < 15000: weight noise (std / 100) and input noise scaled by 2\n",
    "      * i < 20000: input noise scaled by 1/2 only\n",
    "      * otherwise: no noise at all\n",
    "\n",
    "    Reads the notebook globals ``net``, ``net_input_saved``, ``noise``, ``cnn``,\n",
    "    ``matcher_content``, ``imsize`` and ``PLOT``, and increments ``i``.\n",
    "    \"\"\"\n",
    "    global i\n",
    "\n",
    "    # (weight-noise divisor, input-noise scale) for the current phase; None = off.\n",
    "    if i < 10000:\n",
    "        weight_div, input_scale = 50, 10\n",
    "    elif i < 15000:\n",
    "        weight_div, input_scale = 100, 2\n",
    "    elif i < 20000:\n",
    "        weight_div, input_scale = None, 0.5\n",
    "    else:\n",
    "        weight_div, input_scale = None, None\n",
    "\n",
    "    if weight_div is not None:\n",
    "        # Weight noise on the 4-D (convolution) weights. The update must be in\n",
    "        # place (add_): rebinding a loop variable would leave the parameters\n",
    "        # untouched. no_grad keeps the perturbation out of the autograd graph.\n",
    "        with torch.no_grad():\n",
    "            for w in net.parameters():\n",
    "                if w.dim() == 4:\n",
    "                    w.add_(torch.randn_like(w) * w.std() / weight_div)\n",
    "\n",
    "    # Input noise. A local name with a noise-free fallback keeps the forward pass\n",
    "    # defined even after the schedule has run out.\n",
    "    if input_scale is not None:\n",
    "        noisy_input = net_input_saved + (noise.normal_() * input_scale)\n",
    "    else:\n",
    "        noisy_input = net_input_saved\n",
    "\n",
    "    out = net(noisy_input)[:, :, :imsize, :imsize]\n",
    "\n",
    "    cnn(vgg_preprocess_var(out))\n",
    "    total_loss = sum(matcher_content.losses.values())\n",
    "    total_loss.backward()\n",
    "\n",
    "    print ('Iteration %05d    Loss %.3f' % (i, total_loss.item()), '\\r', end='')\n",
    "    if PLOT and i % 1000 == 0:\n",
    "        out_np = np.clip(torch_to_np(out), 0, 1)\n",
    "        plot_image_grid([out_np], 3, 3)\n",
    "\n",
    "    i += 1\n",
    "\n",
    "    return total_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iter = 20000\n",
    "LR = 0.01\n",
    "\n",
    "# Clean copy of the input; closure() adds fresh noise to it on every call.\n",
    "net_input_saved = net_input.detach().clone()\n",
    "# Scratch tensor of the same shape, refilled in place by noise.normal_() in closure().\n",
    "noise = net_input.detach().clone()\n",
    "# Reset the iteration counter that drives the noise schedule in closure().\n",
    "i=0\n",
    "\n",
    "matcher_content.mode = 'match'\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
